{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bs4 import BeautifulSoup\n",
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_string(string):\n",
    "    \"\"\"Split ``string`` into tokens made of letters, digits, '-' and '/'.\n",
    "\n",
    "    Every run of characters outside ``A-Za-z0-9``, '-', '/' and space is\n",
    "    replaced by a single space, then the text is split on whitespace.\n",
    "\n",
    "    Returns a list of tokens (empty when nothing survives the filter).\n",
    "    \"\"\"\n",
    "    # Raw string: '\\-' and '\\/' are invalid escapes in a normal literal and\n",
    "    # raise a DeprecationWarning/SyntaxWarning on newer Python versions.\n",
    "    # split() already removes surrounding whitespace, so no strip() is needed.\n",
    "    return re.sub(r'[^A-Za-z0-9\\-\\/ ]+', ' ', string).split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_raw(filename):\n",
    "    \"\"\"Read an entity-tagged text file and return aligned tokens and labels.\n",
    "\n",
    "    The file is parsed with BeautifulSoup and walked line by line through\n",
    "    ``soup.prettify()``. The line that follows an opening tag is labelled\n",
    "    with that tag's name, lines outside any tag are labelled 'OTHER', and\n",
    "    closing-tag lines produce no tokens.\n",
    "\n",
    "    Returns ``(texts, labels)``: two lists of equal length.\n",
    "    \"\"\"\n",
    "    with open(filename, 'r') as fopen:\n",
    "        entities = fopen.read()\n",
    "    soup = BeautifulSoup(entities, 'html.parser')\n",
    "    inside_tag = ''\n",
    "    texts, labels = [], []\n",
    "    for sentence in soup.prettify().split('\\n'):\n",
    "        if sentence.startswith('</'):\n",
    "            # A closing tag never carries text. Testing it first also covers\n",
    "            # an entity with no text, whose closing tag used to be tokenised\n",
    "            # and labelled as if it were the entity's content.\n",
    "            inside_tag = ''\n",
    "        elif inside_tag:\n",
    "            # Only the single line right after the opening tag is entity text.\n",
    "            tokens = process_string(sentence)\n",
    "            texts += tokens\n",
    "            labels += [inside_tag] * len(tokens)\n",
    "            inside_tag = ''\n",
    "        elif sentence.startswith('<'):\n",
    "            # Tag name = everything between '<' and the first '>'.\n",
    "            inside_tag = sentence.split('>')[0][1:]\n",
    "        else:\n",
    "            tokens = process_string(sentence)\n",
    "            texts += tokens\n",
    "            labels += ['OTHER'] * len(tokens)\n",
    "    assert (len(texts)==len(labels)), \"length texts and labels are not same\"\n",
    "    print('len texts and labels: ', len(texts))\n",
    "    return texts,labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len texts and labels:  34012\n",
      "len texts and labels:  9249\n"
     ]
    }
   ],
   "source": [
    "# Parse both annotated files into (tokens, labels) pairs.\n",
    "train_texts, train_labels = parse_raw('data_train.txt')\n",
    "test_texts, test_labels = parse_raw('data_test.txt')\n",
    "# Fold the test tokens into the training lists (in place).\n",
    "train_texts.extend(test_texts)\n",
    "train_labels.extend(test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['OTHER', 'location', 'organization', 'person', 'quantity', 'time'],\n",
       "       dtype='<U12'), array([35613,  1536,  1592,  2358,  1336,   826]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count per label after the test file was merged into the training lists.\n",
    "np.unique(train_labels,return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the entity lexicon: each non-blank line is split on whitespace and read\n",
    "# as '<token> <tag> ...'.\n",
    "with open('entities-bm-normalize-v3.txt','r') as fopen:\n",
    "    # Iterating the file keeps the last entry even when there is no trailing\n",
    "    # newline; split('\\n')[:-1] silently dropped it in that case.\n",
    "    entities_bm = [line.split() for line in fopen]\n",
    "# Skip blank lines and force the word 'jam' to TIME. This must be an equality\n",
    "# test: \"i[0] in 'jam'\" was a substring test that also relabelled tokens such\n",
    "# as 'a', 'm', 'am' or 'ja'.\n",
    "entities_bm = [[row[0], 'TIME' if row[0] == 'jam' else row[1]]\n",
    "               for row in entities_bm if row]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'KN'\n",
      "'KA'\n"
     ]
    }
   ],
   "source": [
    "# Lexicon tag -> corpus label. Tags missing from this table are printed and skipped.\n",
    "replace_by = {'LOC':'location','PRN':'person','NORP':'organization','ORG':'organization','LAW':'law',\n",
    "             'EVENT':'OTHER','FAC':'organization','TIME':'time','O':'OTHER','ART':'person','DOC':'law'}\n",
    "# NOTE: the loop appends to train_texts/train_labels in place, so running this\n",
    "# cell twice adds the lexicon twice; re-run the cell that builds those lists first.\n",
    "for i in entities_bm:\n",
    "    tokens = process_string(i[0])\n",
    "    if not tokens:\n",
    "        continue\n",
    "    tag = i[1]\n",
    "    if tag not in replace_by:\n",
    "        # Prints the same text the caught KeyError used to print, but without a\n",
    "        # blanket 'except Exception' that would also hide genuine bugs.\n",
    "        print(repr(tag))\n",
    "        continue\n",
    "    train_labels.append(replace_by[tag])\n",
    "    # Only the first token of a multi-token entry is kept.\n",
    "    train_texts.append(tokens[0])\n",
    "\n",
    "assert (len(train_texts)==len(train_labels)), \"length texts and labels are not same\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['OTHER', 'law', 'location', 'organization', 'person', 'quantity',\n",
       "        'time'], dtype='<U12'),\n",
       " array([47406,   107,  2010,  2435,  3913,  1336,  1240]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label distribution again, now that the lexicon entries have been appended to train_labels.\n",
    "np.unique(train_labels,return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup tables shared by the whole notebook; parse_XY grows them in place.\n",
    "# 'PAD' holds id 0 in every table.\n",
    "word2idx = {'PAD': 0,'NUM':1,'UNK':2}\n",
    "tag2idx = {'PAD': 0}\n",
    "char2idx = {'PAD': 0}\n",
    "# Next unused id for each table.\n",
    "word_idx = 3\n",
    "tag_idx = 1\n",
    "char_idx = 1\n",
    "\n",
    "def parse_XY(texts, labels):\n",
    "    \"\"\"Encode tokens and labels as integer ids, extending the global tables.\n",
    "\n",
    "    Each token is lower-cased; unseen characters, labels and words receive\n",
    "    the next free id from ``char_idx``, ``tag_idx`` and ``word_idx``.\n",
    "\n",
    "    Returns ``(X, Y)``: a list of word ids and a NumPy array of tag ids.\n",
    "    \"\"\"\n",
    "    global word2idx, tag2idx, char2idx, word_idx, tag_idx, char_idx\n",
    "    word_ids, tag_ids = [], []\n",
    "    for position, raw_token in enumerate(texts):\n",
    "        token = raw_token.lower()\n",
    "        label = labels[position]\n",
    "        # Register characters not seen before.\n",
    "        for char in token:\n",
    "            if char in char2idx:\n",
    "                continue\n",
    "            char2idx[char] = char_idx\n",
    "            char_idx += 1\n",
    "        # Register the label, then record its id.\n",
    "        if label not in tag2idx:\n",
    "            tag2idx[label] = tag_idx\n",
    "            tag_idx += 1\n",
    "        tag_ids.append(tag2idx[label])\n",
    "        # Register the word, then record its id.\n",
    "        if token not in word2idx:\n",
    "            word2idx[token] = word_idx\n",
    "            word_idx += 1\n",
    "        word_ids.append(word2idx[token])\n",
    "    return word_ids, np.array(tag_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['OTHER', 'law', 'location', 'organization', 'person', 'quantity',\n",
       "       'time'], dtype='<U12')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct label names present in train_labels.\n",
    "np.unique(train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every token/label pair, then build the inverse (id -> string) tables.\n",
    "X, Y = parse_XY(train_texts, train_labels)\n",
    "idx2word = dict(zip(word2idx.values(), word2idx.keys()))\n",
    "idx2tag = dict(zip(tag2idx.values(), tag2idx.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_len = 50\n",
    "def iter_seq(x):\n",
    "    \"\"\"Return every window of ``seq_len`` consecutive items of ``x`` (stride 1).\n",
    "\n",
    "    The result is a NumPy array with one row per window; it is empty when\n",
    "    ``x`` is shorter than ``seq_len``.\n",
    "    \"\"\"\n",
    "    # '+ 1' keeps the final window x[len(x)-seq_len:], which the previous\n",
    "    # range(0, len(x)-seq_len, 1) left out.\n",
    "    return np.array([x[i: i+seq_len] for i in range(len(x) - seq_len + 1)])\n",
    "\n",
    "def to_train_seq(*args):\n",
    "    \"\"\"Apply iter_seq to each argument and return the results as a list.\"\"\"\n",
    "    return [iter_seq(x) for x in args]\n",
    "\n",
    "def generate_char_seq(batch):\n",
    "    \"\"\"Build the character-id tensor for a 2-D array of word ids.\n",
    "\n",
    "    The output has shape (batch.shape[0], batch.shape[1], maxlen) and dtype\n",
    "    int32, where maxlen is the length of the longest word in ``batch``.\n",
    "    Each word is written right-aligned and reversed (its first character\n",
    "    lands in the last slot); unused leading slots stay 0.\n",
    "    \"\"\"\n",
    "    maxlen = max(len(idx2word[word_id]) for row in batch for word_id in row)\n",
    "    n_rows, n_cols = batch.shape[0], batch.shape[1]\n",
    "    char_ids = np.zeros((n_rows, n_cols, maxlen), dtype=np.int32)\n",
    "    for r in range(n_rows):\n",
    "        for c in range(n_cols):\n",
    "            word = idx2word[batch[r, c]]\n",
    "            # offset 1 -> last slot, offset 2 -> second to last, ...\n",
    "            for offset, char in enumerate(word, start=1):\n",
    "                char_ids[r, c, -offset] = char2idx[char]\n",
    "    return char_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Persist all lookup tables in one JSON file.\n",
    "# NOTE: JSON object keys are strings, so the integer keys of idx2tag and\n",
    "# idx2word come back as strings when this file is loaded.\n",
    "with open('crf-lstm-concat-bidirectional.json','w') as fopen:\n",
    "    json.dump({'idx2tag':idx2tag,'idx2word':idx2word,\n",
    "               'word2idx':word2idx,'tag2idx':tag2idx,'char2idx':char2idx}, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(58397, 50)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slice the encoded words and tags into windows with iter_seq, then build the\n",
    "# character-id tensor for the word windows.\n",
    "X_seq, Y_seq = to_train_seq(X, Y)\n",
    "X_char_seq = generate_char_seq(X_seq)\n",
    "X_seq.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.utils import to_categorical\n",
    "# One-hot encode the tag ids of every window; num_classes counts every entry\n",
    "# of tag2idx, including 'PAD'.\n",
    "Y_seq_3d = [to_categorical(i, num_classes=len(tag2idx)) for i in Y_seq]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation was removed in scikit-learn 0.20; model_selection is\n",
    "# the module this notebook already imports KFold from.\n",
    "from sklearn.model_selection import train_test_split\n",
    "# NOTE(review): the windows produced by iter_seq overlap (stride 1), so a random\n",
    "# split places near-duplicate windows in both train and test; scores measured\n",
    "# on test_X are therefore likely optimistic.\n",
    "# random_state is fixed so the split is the same on every run.\n",
    "train_X, test_X, train_Y, test_Y, train_char, test_char = train_test_split(X_seq, Y_seq_3d, X_char_seq,\n",
    "                                                                           test_size=0.1, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.2\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "# Record which Keras version this notebook is running with.\n",
    "print(keras.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Model, Input\n",
    "from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional, Reshape, Concatenate, Lambda\n",
    "from keras_contrib.layers import CRF\n",
    "from keras import backend as K\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "# Hand a new tf.InteractiveSession to Keras via set_session.\n",
    "set_session(tf.InteractiveSession())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_len = seq_len\n",
    "char_embed_dim = 128   # size of each character embedding\n",
    "char_lstm_units = 50   # units per direction of the character LSTM\n",
    "\n",
    "input_word = Input(shape=(None,))\n",
    "input_char = Input(shape=(None,None,))\n",
    "\n",
    "# Character branch: one vector per word, computed from its characters.\n",
    "model_char = Embedding(input_dim=len(char2idx) + 1, output_dim=char_embed_dim)(input_char)\n",
    "# Runtime shape of the embedded characters; both reshape helpers close over it.\n",
    "s = K.shape(model_char)\n",
    "\n",
    "# The two reshape helpers have distinct names; previously both were called\n",
    "# backend_reshape and the second definition shadowed the first.\n",
    "def merge_word_axes(x):\n",
    "    \"\"\"Fold the first two axes together so the LSTM gets one sequence per word.\"\"\"\n",
    "    return K.reshape(x, (s[0]*s[1],s[2],char_embed_dim))\n",
    "\n",
    "def sliced(x):\n",
    "    \"\"\"Keep only the last time step of every sequence.\"\"\"\n",
    "    return x[:,-1]\n",
    "\n",
    "def split_word_axes(x):\n",
    "    \"\"\"Undo merge_word_axes: restore the two leading axes around the word vectors.\"\"\"\n",
    "    # 2 * units: the Bidirectional wrapper outputs both directions side by side.\n",
    "    return K.reshape(x, (s[0],s[1],2*char_lstm_units))\n",
    "\n",
    "model_char = Lambda(merge_word_axes)(model_char)\n",
    "model_char = Bidirectional(LSTM(units=char_lstm_units, return_sequences=True, recurrent_dropout=0.1))(model_char)\n",
    "model_char = Lambda(sliced)(model_char)\n",
    "model_char = Lambda(split_word_axes)(model_char)\n",
    "\n",
    "# Word branch: embedding lookup with index 0 masked.\n",
    "model_word = Embedding(input_dim=len(word2idx) + 1, output_dim=64, mask_zero=True)(input_word)\n",
    "\n",
    "# Tagger: concatenate both branches, then BiLSTM -> per-step Dense -> CRF.\n",
    "concated_word_char = Concatenate(-1)([model_char,model_word])\n",
    "tagger = Bidirectional(LSTM(units=50, return_sequences=True, recurrent_dropout=0.1))(concated_word_char)\n",
    "tagger = TimeDistributed(Dense(50, activation=\"relu\"))(tagger)\n",
    "crf = CRF(len(tag2idx))\n",
    "output = crf(tagger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the graph in a Keras Model: inputs are [word ids, char ids], output is the CRF layer.\n",
    "model = Model(inputs=[input_word, input_char], outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the loss and accuracy metric provided by the CRF layer itself.\n",
    "model.compile(optimizer=\"rmsprop\", loss=crf.loss_function, metrics=[crf.accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_2 (InputLayer)            (None, None, None)   0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding_1 (Embedding)         (None, None, None, 1 5120        input_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, None, 128)    0           embedding_1[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_1 (Bidirectional) (None, None, 100)    71600       lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_2 (Lambda)               (None, 100)          0           bidirectional_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "input_1 (InputLayer)            (None, None)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "lambda_3 (Lambda)               (None, None, 100)    0           lambda_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_2 (Embedding)         (None, None, 64)     607936      input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, None, 164)    0           lambda_3[0][0]                   \n",
      "                                                                 embedding_2[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_2 (Bidirectional) (None, None, 100)    86000       concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "time_distributed_1 (TimeDistrib (None, None, 50)     5050        bidirectional_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "crf_1 (CRF)                     (None, None, 8)      488         time_distributed_1[0][0]         \n",
      "==================================================================================================\n",
      "Total params: 776,194\n",
      "Trainable params: 776,194\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layer-by-layer summary of the compiled model.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"702pt\" viewBox=\"0.00 0.00 383.00 702.00\" width=\"383pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 698)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-698 379,-698 379,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 140671817582128 -->\n",
       "<g class=\"node\" id=\"node1\"><title>140671817582128</title>\n",
       "<polygon fill=\"none\" points=\"72,-657.5 72,-693.5 197,-693.5 197,-657.5 72,-657.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-671.8\">input_2: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140671817582072 -->\n",
       "<g class=\"node\" id=\"node2\"><title>140671817582072</title>\n",
       "<polygon fill=\"none\" points=\"54,-584.5 54,-620.5 215,-620.5 215,-584.5 54,-584.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-598.8\">embedding_1: Embedding</text>\n",
       "</g>\n",
       "<!-- 140671817582128&#45;&gt;140671817582072 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>140671817582128-&gt;140671817582072</title>\n",
       "<path d=\"M134.5,-657.313C134.5,-649.289 134.5,-639.547 134.5,-630.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-630.529 134.5,-620.529 131,-630.529 138,-630.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817583808 -->\n",
       "<g class=\"node\" id=\"node3\"><title>140671817583808</title>\n",
       "<polygon fill=\"none\" points=\"73,-511.5 73,-547.5 196,-547.5 196,-511.5 73,-511.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-525.8\">lambda_1: Lambda</text>\n",
       "</g>\n",
       "<!-- 140671817582072&#45;&gt;140671817583808 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>140671817582072-&gt;140671817583808</title>\n",
       "<path d=\"M134.5,-584.313C134.5,-576.289 134.5,-566.547 134.5,-557.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-557.529 134.5,-547.529 131,-557.529 138,-557.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817638856 -->\n",
       "<g class=\"node\" id=\"node4\"><title>140671817638856</title>\n",
       "<polygon fill=\"none\" points=\"0,-438.5 0,-474.5 269,-474.5 269,-438.5 0,-438.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-452.8\">bidirectional_1(lstm_1): Bidirectional(LSTM)</text>\n",
       "</g>\n",
       "<!-- 140671817583808&#45;&gt;140671817638856 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>140671817583808-&gt;140671817638856</title>\n",
       "<path d=\"M134.5,-511.313C134.5,-503.289 134.5,-493.547 134.5,-484.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-484.529 134.5,-474.529 131,-484.529 138,-484.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817646320 -->\n",
       "<g class=\"node\" id=\"node5\"><title>140671817646320</title>\n",
       "<polygon fill=\"none\" points=\"73,-365.5 73,-401.5 196,-401.5 196,-365.5 73,-365.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-379.8\">lambda_2: Lambda</text>\n",
       "</g>\n",
       "<!-- 140671817638856&#45;&gt;140671817646320 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>140671817638856-&gt;140671817646320</title>\n",
       "<path d=\"M134.5,-438.313C134.5,-430.289 134.5,-420.547 134.5,-411.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-411.529 134.5,-401.529 131,-411.529 138,-411.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817646600 -->\n",
       "<g class=\"node\" id=\"node7\"><title>140671817646600</title>\n",
       "<polygon fill=\"none\" points=\"73,-292.5 73,-328.5 196,-328.5 196,-292.5 73,-292.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-306.8\">lambda_3: Lambda</text>\n",
       "</g>\n",
       "<!-- 140671817646320&#45;&gt;140671817646600 -->\n",
       "<g class=\"edge\" id=\"edge5\"><title>140671817646320-&gt;140671817646600</title>\n",
       "<path d=\"M134.5,-365.313C134.5,-357.289 134.5,-347.547 134.5,-338.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-338.529 134.5,-328.529 131,-338.529 138,-338.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817581960 -->\n",
       "<g class=\"node\" id=\"node6\"><title>140671817581960</title>\n",
       "<polygon fill=\"none\" points=\"232,-365.5 232,-401.5 357,-401.5 357,-365.5 232,-365.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"294.5\" y=\"-379.8\">input_1: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140671817583304 -->\n",
       "<g class=\"node\" id=\"node8\"><title>140671817583304</title>\n",
       "<polygon fill=\"none\" points=\"214,-292.5 214,-328.5 375,-328.5 375,-292.5 214,-292.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"294.5\" y=\"-306.8\">embedding_2: Embedding</text>\n",
       "</g>\n",
       "<!-- 140671817581960&#45;&gt;140671817583304 -->\n",
       "<g class=\"edge\" id=\"edge6\"><title>140671817581960-&gt;140671817583304</title>\n",
       "<path d=\"M294.5,-365.313C294.5,-357.289 294.5,-347.547 294.5,-338.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"298,-338.529 294.5,-328.529 291,-338.529 298,-338.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817267632 -->\n",
       "<g class=\"node\" id=\"node9\"><title>140671817267632</title>\n",
       "<polygon fill=\"none\" points=\"130.5,-219.5 130.5,-255.5 298.5,-255.5 298.5,-219.5 130.5,-219.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214.5\" y=\"-233.8\">concatenate_1: Concatenate</text>\n",
       "</g>\n",
       "<!-- 140671817646600&#45;&gt;140671817267632 -->\n",
       "<g class=\"edge\" id=\"edge7\"><title>140671817646600-&gt;140671817267632</title>\n",
       "<path d=\"M153.866,-292.313C163.987,-283.33 176.536,-272.193 187.587,-262.386\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"190.157,-264.784 195.313,-255.529 185.51,-259.549 190.157,-264.784\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140671817583304&#45;&gt;140671817267632 -->\n",
       "<g class=\"edge\" id=\"edge8\"><title>140671817583304-&gt;140671817267632</title>\n",
       "<path d=\"M275.134,-292.313C265.013,-283.33 252.464,-272.193 241.413,-262.386\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"243.49,-259.549 233.687,-255.529 238.843,-264.784 243.49,-259.549\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140670340079688 -->\n",
       "<g class=\"node\" id=\"node10\"><title>140670340079688</title>\n",
       "<polygon fill=\"none\" points=\"80,-146.5 80,-182.5 349,-182.5 349,-146.5 80,-146.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214.5\" y=\"-160.8\">bidirectional_2(lstm_2): Bidirectional(LSTM)</text>\n",
       "</g>\n",
       "<!-- 140671817267632&#45;&gt;140670340079688 -->\n",
       "<g class=\"edge\" id=\"edge9\"><title>140671817267632-&gt;140670340079688</title>\n",
       "<path d=\"M214.5,-219.313C214.5,-211.289 214.5,-201.547 214.5,-192.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"218,-192.529 214.5,-182.529 211,-192.529 218,-192.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140670331073760 -->\n",
       "<g class=\"node\" id=\"node11\"><title>140670331073760</title>\n",
       "<polygon fill=\"none\" points=\"58,-73.5 58,-109.5 371,-109.5 371,-73.5 58,-73.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214.5\" y=\"-87.8\">time_distributed_1(dense_1): TimeDistributed(Dense)</text>\n",
       "</g>\n",
       "<!-- 140670340079688&#45;&gt;140670331073760 -->\n",
       "<g class=\"edge\" id=\"edge10\"><title>140670340079688-&gt;140670331073760</title>\n",
       "<path d=\"M214.5,-146.313C214.5,-138.289 214.5,-128.547 214.5,-119.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"218,-119.529 214.5,-109.529 211,-119.529 218,-119.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140670340080920 -->\n",
       "<g class=\"node\" id=\"node12\"><title>140670340080920</title>\n",
       "<polygon fill=\"none\" points=\"175.5,-0.5 175.5,-36.5 253.5,-36.5 253.5,-0.5 175.5,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"214.5\" y=\"-14.8\">crf_1: CRF</text>\n",
       "</g>\n",
       "<!-- 140670331073760&#45;&gt;140670340080920 -->\n",
       "<g class=\"edge\" id=\"edge11\"><title>140670331073760-&gt;140670340080920</title>\n",
       "<path d=\"M214.5,-73.3129C214.5,-65.2895 214.5,-55.5475 214.5,-46.5691\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"218,-46.5288 214.5,-36.5288 211,-46.5289 218,-46.5288\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "# Render the layer graph inline as SVG, using 'dot' as the layout program.\n",
    "SVG(model_to_dot(model).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 47301 samples, validate on 5256 samples\n",
      "Epoch 1/3\n",
      "47301/47301 [==============================] - 521s 11ms/step - loss: 0.1176 - acc: 0.9619 - val_loss: 0.0077 - val_acc: 0.9975\n",
      "Epoch 2/3\n",
      "47301/47301 [==============================] - 524s 11ms/step - loss: 0.0053 - acc: 0.9982 - val_loss: 0.0054 - val_acc: 0.9980\n",
      "Epoch 3/3\n",
      "47301/47301 [==============================] - 520s 11ms/step - loss: 0.0027 - acc: 0.9991 - val_loss: 0.0027 - val_acc: 0.9991\n"
     ]
    }
   ],
   "source": [
    "# Train for 3 epochs; validation_split=0.1 holds out a tenth of the training\n",
    "# windows for validation.\n",
    "history = model.fit([train_X,train_char], np.array(train_Y), batch_size=32, epochs=3,\n",
    "                    validation_split=0.1, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5840/5840 [==============================] - 20s 3ms/step\n"
     ]
    }
   ],
   "source": [
    "# Run the trained model on the held-out windows.\n",
    "predicted=model.predict([test_X,test_char],verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred):\n",
    "    \"\"\"Convert per-token vectors into label names.\n",
    "\n",
    "    ``pred`` is iterated as sequences of vectors; each vector is replaced by\n",
    "    the ``idx2tag`` entry of its argmax. Returns a list of lists of labels.\n",
    "    \"\"\"\n",
    "    return [[idx2tag[np.argmax(scores)] for scores in sequence]\n",
    "            for sequence in pred]\n",
    "    \n",
    "# Decode both the predictions and the one-hot targets back to label names.\n",
    "# NOTE: this rebinds test_labels, which earlier held the labels parsed from data_test.txt.\n",
    "pred_labels = pred2label(predicted)\n",
    "test_labels = pred2label(test_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       OTHER       1.00      1.00      1.00    236000\n",
      "         law       1.00      0.99      1.00       472\n",
      "    location       1.00      1.00      1.00     10303\n",
      "organization       1.00      1.00      1.00     12525\n",
      "      person       1.00      1.00      1.00     19828\n",
      "    quantity       1.00      1.00      1.00      6599\n",
      "        time       1.00      1.00      1.00      6273\n",
      "\n",
      " avg / total       1.00      1.00      1.00    292000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "# Flatten all windows into one token list per side and print the per-label report.\n",
    "print(classification_report(np.array(test_labels).ravel(), np.array(pred_labels).ravel()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained weights to an HDF5 file (save_weights stores weights only).\n",
    "model.save_weights('crf-lstm-concat-bidirectional.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
